Model improvements

Regularizing priors, centering, and normalization

Elizabeth King
Kevin Middleton

Improving sampling

  • Check for errors in the model
    • Particularly if you are writing .stan files
  • Determine better priors (use your knowledge)
  • Redefine model (centering, standardizing)
  • Ask Stan to work harder (adapt_delta, max_treedepth, stepsize)

Divergences

Warning: 2070 of 20000 (10.0%) transitions ended with a divergence.

  • The HMC particle simulation fails in an unexpected way
  • You (mostly) don’t need to understand why
    • But you should be concerned about divergences
  • Taming Divergences in Stan Models by Martin Modrák

Divergences

The amazing thing about divergences is that what is essentially a numerical problem actually signals a wide array of possibly severe modelling problems. Be glad - few algorithms (in any area) have such a clear signal that things went wrong. – Martin Modrák

Explore and (hopefully) eliminate divergences

Stan control parameters

  1. adapt_delta
    • Defaults to 0.8
    • Values closer to 1 (0.9, 0.95, 0.99) decrease the step size and usually reduce divergences
  2. max_treedepth
    • Defaults to 10
    • Raise it if you get warnings about reaching the maximum tree depth
    • Allows more simulation (leapfrog) steps per iteration

Both slow down sampling.
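
As a rough illustration of what these two settings control, the sketch below builds a control list of the kind passed to brm() and computes the cap on leapfrog steps implied by max_treedepth. The values 0.95 and 12 are arbitrary choices for illustration, not recommendations from these slides, and max_leapfrog() is a hypothetical helper.

```r
# Illustrative control settings (values are assumptions, not recommendations)
control <- list(adapt_delta = 0.95, max_treedepth = 12)

# NUTS doubles the trajectory at most max_treedepth times, so the number of
# leapfrog steps in one iteration is capped at 2^max_treedepth - 1
max_leapfrog <- function(treedepth) 2^treedepth - 1

default_cap <- max_leapfrog(10)                     # Stan's default max_treedepth
raised_cap  <- max_leapfrog(control$max_treedepth)  # after raising it to 12

# The list is handed to brms unchanged, for example:
# fm <- brm(formula, data = D, control = control)
```

Each extra level of tree depth roughly doubles the worst-case work per iteration, which is why raising max_treedepth can slow sampling noticeably.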

Multilevel data

library(tidyverse)  # mutate() and fct_inorder()

D <- abdData::Zooplankton |>
  mutate(treatment = fct_inorder(treatment),
         block = factor(block))
D
   treatment zooplankton block
1    control         4.1     1
2        low         2.2     1
3       high         1.3     1
4    control         3.2     2
5        low         2.4     2
6       high         2.0     2
7    control         3.0     3
8        low         1.5     3
9       high         1.0     3
10   control         2.3     4
11       low         1.3     4
12      high         1.0     4
13   control         2.5     5
14       low         2.6     5
15      high         1.6     5

Multilevel model

library(brms)

fm <- brm(zooplankton ~ treatment - 1 + (1 | block),
          data = D,
          seed = 4547359,
          iter = 5e3, chains = 4, cores = 4)
Warning messages:
1: There were 45 divergent transitions after warmup. See
https://mc-stan.org/misc/warnings.html#divergent-transitions-after-warmup
to find out why this is a problem and how to eliminate them. 
2: Examine the pairs() plot to diagnose sampling problems

Pairs

library(bayesplot)

mcmc_pairs(fm,
           pars = c("sigma", "sd_block__Intercept"),
           regex_pars = "^b_",
           np = nuts_params(fm))

Traceplot

mcmc_trace(fm, 
           pars = c("sigma", "sd_block__Intercept"),
           regex_pars = "^b_",
           np = nuts_params(fm))

Traceplot

Zoom in on a specific region of samples.

mcmc_trace(fm, 
           pars = c("sigma", "sd_block__Intercept"),
           regex_pars = "^b_",
           np = nuts_params(fm),
           window = c(1800, 2000))
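
To decide which window to zoom in on, it can help to list the iterations where divergences occurred. This is a minimal sketch, assuming np is the long-format data frame returned by nuts_params() (columns Chain, Iteration, Parameter, Value). divergent_iterations() is a hypothetical helper, and np_toy is made-up data used only so the example runs without a fitted model.

```r
# Hypothetical helper: rows of a nuts_params()-style data frame that diverged
divergent_iterations <- function(np) {
  div <- np[np$Parameter == "divergent__" & np$Value == 1,
            c("Chain", "Iteration")]
  rownames(div) <- NULL
  div
}

# Toy stand-in for nuts_params(fm); the values are invented for illustration
np_toy <- data.frame(
  Chain     = c(1, 1, 1, 2, 2, 2),
  Iteration = c(1, 2, 3, 1, 2, 3),
  Parameter = "divergent__",
  Value     = c(0, 1, 0, 0, 0, 1)
)

div <- divergent_iterations(np_toy)
nrow(div)  # number of divergent transitions in the toy data
```

With a real fit, divergent_iterations(nuts_params(fm)) would give the chain and iteration of each divergence, which suggests a range to pass to the window argument.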

Increase adapt_delta

fm <- brm(zooplankton ~ treatment - 1 + (1 | block),
          data = D,
          seed = 4547359,
          iter = 5e3, chains = 4, cores = 4,
          control = list(adapt_delta = 0.99))
Warning messages:
1: There were 1 divergent transitions after warmup. See
https://mc-stan.org/misc/warnings.html#divergent-transitions-after-warmup
to find out why this is a problem and how to eliminate them. 
2: Examine the pairs() plot to diagnose sampling problems

Examining the priors

prior_summary(fm)
                prior class             coef group resp dpar nlpar lb ub
               (flat)     b                                             
               (flat)     b treatmentcontrol                            
               (flat)     b    treatmenthigh                            
               (flat)     b     treatmentlow                            
 student_t(3, 0, 2.5)    sd                                         0   
 student_t(3, 0, 2.5)    sd                  block                  0   
 student_t(3, 0, 2.5)    sd        Intercept block                  0   
 student_t(3, 0, 2.5) sigma                                         0   
       source
      default
 (vectorized)
 (vectorized)
 (vectorized)
      default
 (vectorized)
 (vectorized)
      default

Examining the data

Regularizing priors

Means for each group (b parameters) are expected to be:

  • within about 6 units of zero (2 standard deviations of a normal(0, 3) prior)
  • positive (no negative zooplankton measurements), hence lb = 0

priors <- prior(normal(0, 3), class = b, lb = 0)

fm <- brm(zooplankton ~ treatment - 1 + (1 | block),
          data = D,
          prior = priors,
          seed = 4547359,
          iter = 5e3, chains = 4, cores = 4,
          control = list(adapt_delta = 0.99))
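
The claim that group means should fall within about 6 units of zero can be checked against the prior directly. The sketch below uses only base R and treats the bounded prior as a normal(0, 3) truncated below at 0; the variable names are illustrative.

```r
# Prior from the slides: normal(0, 3) truncated below at 0 (a half-normal)
prior_sd <- 3

# P(0 < b < 6) under the truncated prior: renormalize by the mass above 0
p_within_6 <- (pnorm(6, 0, prior_sd) - pnorm(0, 0, prior_sd)) /
  (1 - pnorm(0, 0, prior_sd))

round(p_within_6, 3)  # 0.954
```

So roughly 95% of the prior mass sits between 0 and 6, which comfortably covers the observed zooplankton values (1.0 to 4.1).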

Examining the priors

prior_summary(fm)
                prior class             coef group resp dpar nlpar lb ub
         normal(0, 3)     b                                         0   
         normal(0, 3)     b treatmentcontrol                        0   
         normal(0, 3)     b    treatmenthigh                        0   
         normal(0, 3)     b     treatmentlow                        0   
 student_t(3, 0, 2.5)    sd                                         0   
 student_t(3, 0, 2.5)    sd                  block                  0   
 student_t(3, 0, 2.5)    sd        Intercept block                  0   
 student_t(3, 0, 2.5) sigma                                         0   
       source
         user
 (vectorized)
 (vectorized)
 (vectorized)
      default
 (vectorized)
 (vectorized)
      default

Model summary

summary(fm)
 Family: gaussian 
  Links: mu = identity; sigma = identity 
Formula: zooplankton ~ treatment - 1 + (1 | block) 
   Data: D (Number of observations: 15) 
  Draws: 4 chains, each with iter = 5000; warmup = 2500; thin = 1;
         total post-warmup draws = 10000

Group-Level Effects: 
~block (Number of levels: 5) 
              Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sd(Intercept)     0.51      0.39     0.05     1.52 1.00     1569     2605

Population-Level Effects: 
                 Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
treatmentcontrol     2.95      0.37     2.12     3.65 1.00     1845     1567
treatmentlow         1.94      0.37     1.12     2.64 1.00     2152     1526
treatmenthigh        1.32      0.37     0.52     2.02 1.00     1976     1439

Family Specific Parameters: 
      Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sigma     0.54      0.15     0.33     0.90 1.00     2734     4127

Draws were sampled using sampling(NUTS). For each parameter, Bulk_ESS
and Tail_ESS are effective sample size measures, and Rhat is the potential
scale reduction factor on split chains (at convergence, Rhat = 1).

Mean centering

  • Stan samples most efficiently when parameters are near zero
  • Useful if you care only about differences, not absolute magnitudes
  • Subtract the mean to center the data on zero
    • The mean can be added back later

Centered data

  • Don’t set the lower bound to 0: centered group means can be negative

D <- D |> mutate(zooplankton_c = zooplankton - mean(zooplankton))

priors <- prior(normal(0, 2), class = b)

fm <- brm(zooplankton_c ~ treatment - 1 + (1 | block),
          data = D,
          prior = priors,
          seed = 4547359,
          iter = 5e3, chains = 4, cores = 4,
          control = list(adapt_delta = 0.99))

Model summary

summary(fm)
 Family: gaussian 
  Links: mu = identity; sigma = identity 
Formula: zooplankton_c ~ treatment - 1 + (1 | block) 
   Data: D (Number of observations: 15) 
  Draws: 4 chains, each with iter = 5000; warmup = 2500; thin = 1;
         total post-warmup draws = 10000

Group-Level Effects: 
~block (Number of levels: 5) 
              Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sd(Intercept)     0.47      0.35     0.03     1.35 1.00     1879     3023

Population-Level Effects: 
                 Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
treatmentcontrol     0.86      0.35     0.15     1.53 1.00     3505     3702
treatmentlow        -0.15      0.34    -0.85     0.53 1.00     3399     3507
treatmenthigh       -0.76      0.35    -1.46    -0.06 1.00     3246     3859

Family Specific Parameters: 
      Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sigma     0.55      0.16     0.33     0.93 1.00     2791     4485

Draws were sampled using sampling(NUTS). For each parameter, Bulk_ESS
and Tail_ESS are effective sample size measures, and Rhat is the potential
scale reduction factor on split chains (at convergence, Rhat = 1).
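
Estimates from the centered model are deviations from the grand mean, so adding the mean back returns them to the original scale. A minimal sketch in base R, using the zooplankton values from the data table and the posterior means reported in the summary; the variable names are illustrative.

```r
# Zooplankton values from the data table, in row order
zooplankton <- c(4.1, 2.2, 1.3, 3.2, 2.4, 2.0, 3.0, 1.5, 1.0,
                 2.3, 1.3, 1.0, 2.5, 2.6, 1.6)
grand_mean <- mean(zooplankton)

# Posterior means on the centered scale, copied from the model summary
centered_est <- c(control = 0.86, low = -0.15, high = -0.76)

# Adding the grand mean returns the estimates to the original scale
original_scale <- centered_est + grand_mean
round(original_scale, 2)
```

The same shift can be applied to every posterior draw, so credible intervals move by the grand mean as well while differences between treatments are unchanged.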

Normalization

Convert data to Z-scores:

\[Z_i = \frac{(Y_i - \bar{Y})}{sd(Y)}\]

  • Mean is 0, standard deviation is 1
  • Units are now standard deviations
    • More difficult to interpret
  • Absolute values > 3 are very uncommon
  • Useful when predictors are on very different scales
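
The Z-score formula above can be applied in base R as sketched here, again using the zooplankton values from the data table. The variable names are illustrative; scale() is the built-in function that performs the same computation.

```r
# Zooplankton values from the data table, in row order
zooplankton <- c(4.1, 2.2, 1.3, 3.2, 2.4, 2.0, 3.0, 1.5, 1.0,
                 2.3, 1.3, 1.0, 2.5, 2.6, 1.6)

# Z-scores by hand: subtract the mean, divide by the standard deviation
z <- (zooplankton - mean(zooplankton)) / sd(zooplankton)

# scale() does the same and returns a one-column matrix
z_scale <- as.numeric(scale(zooplankton))

# Back-transform to the original units
back <- z * sd(zooplankton) + mean(zooplankton)

c(mean = mean(z), sd = sd(z))
```

Keeping the original mean and standard deviation around makes it possible to report results in the measured units after fitting on the standardized scale.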